{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Invoke SageMaker Autopilot Model from Athena"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Machine Learning (ML) with Amazon Athena (Preview) lets you write SQL statements in Athena that run ML inference using Amazon SageMaker. This feature simplifies access to ML models for data analysis, eliminating the need for complex programming methods to run inference.\n",
    "\n",
    "To use ML with Athena (Preview), you define an ML with Athena (Preview) function with the `USING FUNCTION` clause. The function points to the Amazon SageMaker model endpoint that you want to use and specifies the variable names and data types to pass to the model. Subsequent clauses in the query reference the function to pass values to the model. The model runs inference based on the values that the query passes and then returns inference results."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/athena_model.png\" width=\"50%\" align=\"left\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pre-Requisite\n",
    "\n",
    "## *Please note that ML with Athena is in Preview and will only work in the following regions that support Preview Functionality:*\n",
    "\n",
    "## *us-east-1, us-west-2, ap-south-1, eu-west-1*\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check if your current region supports AthenaML Preview"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name\n",
    "\n",
    "sm = boto3.Session().client(service_name=\"sagemaker\", region_name=region)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if region in [\"eu-west-1\", \"ap-south-1\", \"us-east-1\", \"us-west-2\"]:\n",
    "    print(\" [OK] AthenaML IS SUPPORTED IN {}\".format(region))\n",
    "    print(\" [OK] Please proceed with this notebook.\")\n",
    "else:\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\" [ERROR] AthenaML IS *NOT* SUPPORTED IN {} !!\".format(region))\n",
    "    print(\" [INFO] This is OK. SKIP this notebook and move ahead with the workshop.\")\n",
    "    print(\" [INFO] This notebook is not required for the rest of this workshop.\")\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pre-Requisite\n",
    "\n",
    "## _Please wait for the Autopilot Model to deploy!!  Otherwise, this notebook won't work properly._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r autopilot_endpoint_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    autopilot_endpoint_name\n",
    "    print(\"[OK]\")\n",
    "except NameError:\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] There is no Autopilot Model Endpoint deployed.\")\n",
    "    print(\"[INFO] This is OK. Just skip this notebook and move ahead with the next notebook.\")\n",
    "    print(\"[INFO] This notebook is not required for the rest of this workshop.\")\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(autopilot_endpoint_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    resp = sm.describe_endpoint(EndpointName=autopilot_endpoint_name)\n",
    "    status = resp[\"EndpointStatus\"]\n",
    "    if status == \"InService\":\n",
    "        print(\"[OK] Your Autopilot Model Endpoint is in status: {}\".format(status))\n",
    "    elif status == \"Creating\":\n",
    "        print(\"[INFO] Your Autopilot Model Endpoint is in status: {}\".format(status))\n",
    "        print(\"[INFO] Waiting for the endpoint to be InService. Please be patient. This might take a few minutes.\")\n",
    "        sm.get_waiter(\"endpoint_in_service\").wait(EndpointName=autopilot_endpoint_name)\n",
    "    else:\n",
    "        print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "        print(\"[ERROR] Your Autopilot Model is in status: {}\".format(status))\n",
    "        print(\"[INFO] This is OK. Just skip this notebook and move ahead with the next notebook.\")\n",
    "        print(\"[INFO] This notebook is not required for the rest of this workshop.\")\n",
    "        print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "except Exception:\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] There is no Autopilot Model Endpoint deployed.\")\n",
    "    print(\"[INFO] This is OK. Just skip this notebook and move ahead with the next notebook.\")\n",
    "    print(\"[INFO] This notebook is not required for the rest of this workshop.\")\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import PyAthena"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyathena import connect"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create an Athena Table with Sample Reviews"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check for Athena TSV Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r ingest_create_athena_table_tsv_passed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    ingest_create_athena_table_tsv_passed\n",
    "except NameError:\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] YOU HAVE TO RUN ALL NOTEBOOKS IN THE `INGEST` SECTION.\")\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(ingest_create_athena_table_tsv_passed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not ingest_create_athena_table_tsv_passed:\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] YOU HAVE TO RUN ALL NOTEBOOKS IN THE `INGEST` SECTION.\")\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "else:\n",
    "    print(\"[OK]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_staging_dir = \"s3://{}/athena/staging\".format(bucket)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tsv_prefix = \"amazon-reviews-pds/tsv\"\n",
    "database_name = \"dsoaws\"\n",
    "table_name_tsv = \"amazon_reviews_tsv\"\n",
    "table_name = \"product_reviews\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "statement = \"\"\"\n",
    "CREATE TABLE IF NOT EXISTS {}.{} AS \n",
    "SELECT review_id, review_body \n",
    "FROM {}.{}\n",
    "\"\"\".format(\n",
    "    database_name, table_name, database_name, table_name_tsv\n",
    ")\n",
    "\n",
    "print(statement)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "if region in [\"eu-west-1\", \"ap-south-1\", \"us-east-1\", \"us-west-2\"]:\n",
    "    conn = connect(region_name=region, s3_staging_dir=s3_staging_dir)\n",
    "    pd.read_sql(statement, conn)\n",
    "\n",
    "    print(\"[OK]\")\n",
    "else:\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\" [ERROR] AthenaML IS *NOT* SUPPORTED IN {} !!\".format(region))\n",
    "    print(\" [INFO] This is OK. SKIP this notebook and move ahead with the workshop.\")\n",
    "    print(\" [INFO] This notebook is not required for the rest of this workshop.\")\n",
    "    print(\"+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if region in [\"eu-west-1\", \"ap-south-1\", \"us-east-1\", \"us-west-2\"]:\n",
    "    statement = \"SELECT * FROM {}.{} LIMIT 10\".format(database_name, table_name)\n",
    "    conn = connect(region_name=region, s3_staging_dir=s3_staging_dir)\n",
    "    df_table = pd.read_sql(statement, conn)\n",
    "    print(df_table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add the Required `AmazonAthenaPreviewFunctionality` Work Group to Use This Preview Feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botocore.exceptions import ClientError\n",
    "\n",
    "client = boto3.client(\"athena\")\n",
    "\n",
    "if region in [\"eu-west-1\", \"ap-south-1\", \"us-east-1\", \"us-west-2\"]:\n",
    "    try:\n",
    "        response = client.create_work_group(Name=\"AmazonAthenaPreviewFunctionality\")\n",
    "        print(response)\n",
    "    except ClientError as e:\n",
    "        if e.response[\"Error\"][\"Code\"] == \"InvalidRequestException\":\n",
    "            print(\"[OK] Workgroup already exists.\")\n",
    "        else:\n",
    "            print(\"[ERROR] {}\".format(e))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create SQL Query"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `USING FUNCTION` clause specifies one or more ML with Athena (Preview) functions that a subsequent `SELECT` statement in the query can reference. You define the function name, the variable names, and the data types for the variables and return values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "statement = \"\"\"\n",
    "USING FUNCTION predict_star_rating(review_body VARCHAR) \n",
    "    RETURNS VARCHAR TYPE\n",
    "    SAGEMAKER_INVOKE_ENDPOINT WITH (sagemaker_endpoint = '{}'\n",
    ")\n",
    "SELECT review_id, review_body, predict_star_rating(REPLACE(review_body, ',', ' ')) AS predicted_star_rating \n",
    "    FROM {}.{} LIMIT 10\n",
    "    \"\"\".format(\n",
    "    autopilot_endpoint_name, database_name, table_name\n",
    ")\n",
    "\n",
    "print(statement)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Query the Autopilot Endpoint using Data from the Athena Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if region in [\"eu-west-1\", \"ap-south-1\", \"us-east-1\", \"us-west-2\"]:\n",
    "    conn = connect(region_name=region, s3_staging_dir=s3_staging_dir, work_group=\"AmazonAthenaPreviewFunctionality\")\n",
    "    df = pd.read_sql(statement, conn)\n",
    "    print(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Delete Endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sm = boto3.client(\"sagemaker\")\n",
    "\n",
    "if autopilot_endpoint_name:\n",
    "    sm.delete_endpoint(EndpointName=autopilot_endpoint_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%html\n",
    "\n",
    "<p><b>Shutting down your kernel for this notebook to release resources.</b></p>\n",
    "<button class=\"sm-command-button\" data-commandlinker-command=\"kernelmenu:shutdown\" style=\"display:none;\">Shutdown Kernel</button>\n",
    "        \n",
    "<script>\n",
    "try {\n",
    "    els = document.getElementsByClassName(\"sm-command-button\");\n",
    "    els[0].click();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}    \n",
    "</script>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "\n",
    "try {\n",
    "    Jupyter.notebook.save_checkpoint();\n",
    "    Jupyter.notebook.session.delete();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
